import matplotlib.pyplot as plt
from tqdm.notebook import trange
from utils import test
from data import load_classwise_PMNIST, load_classwise_NMNIST
from model import EchoSpike, simple_out
import numpy as np
import torch
import seaborn as sns
from scipy.signal import savgol_filter
import pickle
from main import Args
from matplotlib import pyplot
# Notebook-wide configuration: high-DPI figures, a muted color palette,
# and the path of the pretrained model to analyze.
pyplot.rcParams['figure.dpi'] = 600
color_list = sns.color_palette('muted')

device = 'cpu'
folder = 'models/'
model_name = f'{folder}online_nmnist.pt'

# Alternative: restore the Args the model was trained with from its pickle.
# with open(model_name[:-3] + '_args.pkl', 'rb') as f:
#     args = pickle.load(f)
# args.device = device
args = Args()
print(vars(args))
/home/lars/miniconda3/lib/python3.9/site-packages/torchvision/datapoints/__init__.py:12: UserWarning: The torchvision.datapoints and torchvision.transforms.v2 namespaces are still Beta. While we do not expect major breaking changes, some APIs may still change according to user feedback. Please submit any feedback you may have in this issue: https://github.com/pytorch/vision/issues/6753, and you can also check out https://github.com/pytorch/vision/issues/7319 to learn more about the APIs that we suspect might involve future changes. You can silence this warning by calling torchvision.disable_beta_transforms_warning(). warnings.warn(_BETA_TRANSFORMS_WARNING) /home/lars/miniconda3/lib/python3.9/site-packages/torchvision/transforms/v2/__init__.py:54: UserWarning: The torchvision.datapoints and torchvision.transforms.v2 namespaces are still Beta. While we do not expect major breaking changes, some APIs may still change according to user feedback. Please submit any feedback you may have in this issue: https://github.com/pytorch/vision/issues/6753, and you can also check out https://github.com/pytorch/vision/issues/7319 to learn more about the APIs that we suspect might involve future changes. You can silence this warning by calling torchvision.disable_beta_transforms_warning(). warnings.warn(_BETA_TRANSFORMS_WARNING)
{'model_name': 'test', 'dataset': 'nmnist', 'online': True, 'device': 'cpu', 'recurrency_type': 'none', 'lr': 0.0001, 'epochs': 100, 'augment': True, 'batch_size': 128, 'n_hidden': [200, 200, 200], 'c_y': [2, -1], 'inp_thr': 0.02, 'n_inputs': 2312, 'n_outputs': 10, 'n_time_bins': 10, 'beta': 0.9}
N-MNIST
# Load the class-wise dataloaders for the configured dataset and show one sample.
is_mnist = args.dataset == 'mnist'
if is_mnist:
    # load_NMNIST(n_time_bins, batch_size=batch_size)
    train_loader, train_loader2, test_loader = load_classwise_PMNIST(
        args.n_time_bins, scale=args.poisson_scale, split_train=True)
else:
    train_loader, train_loader2, test_loader = load_classwise_NMNIST(
        args.n_time_bins, split_train=True)

# Plot one example item together with its label.
frames, target = train_loader.next_item(-1)
print(frames.shape, f'Target Digit: {target.item()}')
plt.figure()
if is_mnist:
    plt.imshow(frames[0].view(28, 28), cmap='gray')
else:
    # Sum the event frames over time and show the first polarity channel.
    plt.imshow(frames.squeeze().sum(axis=0).view(2, 34, 34)[0], cmap='gray')
(34, 34, 2)
/home/lars/ownCloud/ETH/Master/Project_2/SNN_CLAPP/data.py:28: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor). self.y = torch.tensor(y)
torch.Size([10, 1, 2312]) Target Digit: 4
# Build the EchoSpike network with the architecture described by `args`
# and restore the pretrained weights from disk.
SNN = EchoSpike(args.n_inputs, args.n_hidden, beta=args.beta, recurrency_type=args.recurrency_type).to(device)
# One-off migration of older checkpoints that used the 'clapp' key prefix:
# state_dict = torch.load(model_name, map_location=args.device)
# state_dict = {key.replace('clapp', 'layers'):value for key, value in state_dict.items()}
# # overwrite the state dict
# torch.save(state_dict, model_name)
SNN.load_state_dict(torch.load(model_name, map_location=device))
# Load and Plot train loss history
echo_train_losses = torch.load(f'{model_name[:-3]}_loss_hist.pt', map_location=device)
for i in range(echo_train_losses.shape[1]):
    # Smooth the noisy per-step loss with a Savitzky-Golay filter (window 99, poly order 1).
    plt.plot(np.linspace(0, args.epochs, len(echo_train_losses)), savgol_filter(echo_train_losses[:,i], 99, 1), label=f'Layer {i+1}', color=color_list[i])
plt.ylabel('EchoSpike Loss')
# no y ticks, because it's not really meaningful
plt.yticks([])
plt.xlabel('Epoch')
# plt.title('EchoSpike Loss During Training for Each Layer');
plt.legend();
# Run the frozen network over the test set to collect per-layer activations,
# the matching targets, and the EchoSpike loss per batch.
echo_activation, target_list, echo_losses = test(SNN, test_loader, device, batch_size=args.batch_size)
print(f'EchoSpike loss per layer: {torch.stack(echo_losses).mean(axis=0).numpy()}')
--------------------------------------------------------------------------- KeyboardInterrupt Traceback (most recent call last) Input In [4], in <cell line: 1>() ----> 1 echo_activation, target_list, echo_losses = test(SNN, test_loader, device, batch_size=args.batch_size) 2 print(f'EchoSpike loss per layer: {torch.stack(echo_losses).mean(axis=0).numpy()}') File ~/ownCloud/ETH/Master/Project_2/SNN_CLAPP/utils.py:96, in test(net, testloader, device, batch_size) 94 target = [torch.randint(testloader.num_classes, (1,)).item() for _ in range(batch_size)] 95 while True: ---> 96 data, target = testloader.next_item(target, contrastive=(bf==-1)) 97 target_list.append(target) 98 data = data.float().to(device) File ~/ownCloud/ETH/Master/Project_2/SNN_CLAPP/data.py:69, in classwise_loader.next_item(self, target, contrastive) 67 targets = [] 68 for i in indeces: ---> 69 im, t = self.data[i] 70 imgs.append(torch.tensor(im).view(im.shape[0], -1)) 71 targets.append(t) File ~/miniconda3/lib/python3.9/site-packages/tonic/cached_dataset.py:137, in DiskCachedDataset.__getitem__(self, item) 135 file_path = os.path.join(self.cache_path, f"{item}_{copy}.hdf5") 136 try: --> 137 data, targets = load_from_disk_cache(file_path) 138 except (FileNotFoundError, OSError) as _: 139 logging.info( 140 f"Data {item}: {file_path} not in cache, generating it now", 141 stacklevel=2, 142 ) File ~/miniconda3/lib/python3.9/site-packages/tonic/cached_dataset.py:221, in load_from_disk_cache(file_path) 216 data = { 217 key: f[f"{name}/{index}/{key}"][()] 218 for key in f[f"{name}/{index}"].keys() 219 } 220 else: --> 221 data = f[f"{name}/{index}"][()] 222 _list.append(data) 223 if len(data_list) == 1: File h5py/_objects.pyx:54, in h5py._objects.with_phil.wrapper() File h5py/_objects.pyx:55, in h5py._objects.with_phil.wrapper() File ~/miniconda3/lib/python3.9/site-packages/h5py/_hl/dataset.py:790, in Dataset.__getitem__(self, args, new_dtype) 787 return self.fields(names, _prior_dtype=new_dtype)[args] 789 if 
new_dtype is None: --> 790 new_dtype = self.dtype 791 mtype = h5t.py_create(new_dtype) 793 # === Special-case region references ==== File h5py/_objects.pyx:54, in h5py._objects.with_phil.wrapper() File h5py/_objects.pyx:55, in h5py._objects.with_phil.wrapper() File ~/miniconda3/lib/python3.9/site-packages/h5py/_hl/dataset.py:541, in Dataset.dtype(self) 538 self._cache_props['_fast_reader'] = rdr 539 return rdr --> 541 @property 542 @with_phil 543 def dtype(self): 544 """Numpy dtype representing the datatype""" 545 return self.id.dtype KeyboardInterrupt:
# Effective input-to-layer weights: for deeper layers, chain the forward
# matrices so each row maps back to input space (the "receptive field").
effective_weights = [SNN.layers[0].fc.weight]
for layer in SNN.layers[1:]:
    effective_weights.append(layer.fc.weight @ effective_weights[-1])

# Raw forward weight matrix of every layer.
for idx, layer in enumerate(SNN.layers):
    plt.figure()
    plt.title(f'Layer {idx}, Forward weights')
    plt.imshow(layer.fc.weight.detach())
    plt.colorbar()

# Receptive fields of the first three neurons of each layer, reshaped to image space.
for lidx, weight in enumerate(effective_weights):
    fig, axs = plt.subplots(1, 3, figsize=(15, 5))
    fig.suptitle(f'Receptive field, Layer {lidx}')
    for neuron in range(3):
        if args.dataset == 'mnist':
            axs[neuron].imshow(weight[neuron].view(28, 28).detach())
        else:
            # N-MNIST: two polarity channels; show the first one.
            axs[neuron].imshow(weight[neuron].view(2, 34, 34)[0].detach())
print(len(echo_activation))
# Regroup the recorded activations: from a list of per-batch layer tuples
# into one (samples, neurons) tensor per hidden layer.
n_layers = len(args.n_hidden)
hidden_activities_transformed = [
    [batch_act[layer] for batch_act in echo_activation]
    for layer in range(n_layers)
]
hidden_activities_transformed = [
    torch.stack(per_layer).reshape(-1, per_layer[0].shape[-1])
    for per_layer in hidden_activities_transformed
]
# Flatten the batched targets to one label per sample.
target_transformed = torch.stack(target_list).flatten()
print(hidden_activities_transformed[0].shape, target_transformed.shape)
from sklearn.decomposition import PCA
from sklearn.manifold import TSNE
from umap import UMAP
# Choose the 2-D embedding used to visualize the hidden-layer activity.
# transform = TSNE()
# transform = PCA()
transform = UMAP()
# NOTE(review): the original cell also precomputed a per-sample color list
# (`colors` / `col`) that was never read; the dead locals were removed —
# colors come from `color_list[i]` per digit below.
for hat in hidden_activities_transformed:
    # Number of Neurons that never spiked during the test set
    print(f'{(hat.sum(axis=0) == 0).sum()} dead neurons')
    hat_transform = transform.fit_transform(hat.detach().cpu().numpy())
    plt.figure(figsize=(8, 8))
    # Plot each digit separately, this makes it easier to color and label them
    for digit in range(args.n_outputs):
        class_indices = np.argwhere(target_transformed.squeeze() == digit).squeeze()
        embedded = hat_transform[class_indices, :]
        plt.scatter(embedded[:, 0], embedded[:, 1], s=6,
                    color=color_list[digit], label=digit, alpha=0.4)
    plt.legend()
79 torch.Size([10112, 200]) torch.Size([10112])
--------------------------------------------------------------------------- KeyboardInterrupt Traceback (most recent call last) Input In [6], in <cell line: 14>() 12 from sklearn.decomposition import PCA 13 from sklearn.manifold import TSNE ---> 14 from umap import UMAP 16 # transform = TSNE() 17 # transform = PCA() 18 transform = UMAP() File ~/miniconda3/lib/python3.9/site-packages/umap/__init__.py:2, in <module> 1 from warnings import warn, catch_warnings, simplefilter ----> 2 from .umap_ import UMAP 4 try: 5 with catch_warnings(): File ~/miniconda3/lib/python3.9/site-packages/umap/umap_.py:48, in <module> 41 from umap.spectral import spectral_layout, tswspectral_layout 42 from umap.layouts import ( 43 optimize_layout_euclidean, 44 optimize_layout_generic, 45 optimize_layout_inverse, 46 ) ---> 48 from pynndescent import NNDescent 49 from pynndescent.distances import named_distances as pynn_named_distances 50 from pynndescent.sparse import sparse_named_distances as pynn_sparse_named_distances File ~/miniconda3/lib/python3.9/site-packages/pynndescent/__init__.py:5, in <module> 1 import sys 3 import numba ----> 5 from .pynndescent_ import NNDescent, PyNNDescentTransformer 7 if sys.version_info[:2] >= (3, 8): 8 import importlib.metadata as importlib_metadata File ~/miniconda3/lib/python3.9/site-packages/pynndescent/pynndescent_.py:22, in <module> 12 from scipy.sparse import ( 13 csr_matrix, 14 coo_matrix, (...) 
17 issparse, 18 ) 20 import heapq ---> 22 import pynndescent.sparse as sparse 23 import pynndescent.sparse_nndescent as sparse_nnd 24 import pynndescent.distances as pynnd_dist File ~/miniconda3/lib/python3.9/site-packages/pynndescent/sparse.py:519, in <module> 502 else: 503 return float(num_non_zero - num_equal) / num_non_zero 506 @numba.njit( 507 [ 508 "f4(i4[::1],f4[::1],i4[::1],f4[::1])", 509 numba.types.float32( 510 numba.types.Array(numba.types.int32, 1, "C", readonly=True), 511 numba.types.Array(numba.types.float32, 1, "C", readonly=True), 512 numba.types.Array(numba.types.int32, 1, "C", readonly=True), 513 numba.types.Array(numba.types.float32, 1, "C", readonly=True), 514 ), 515 ], 516 fastmath=True, 517 locals={"num_non_zero": numba.types.intp, "num_equal": numba.types.intp}, 518 ) --> 519 def sparse_alternative_jaccard(ind1, data1, ind2, data2): 520 num_equal = fast_intersection_size(ind1, ind2) 521 num_non_zero = ind1.shape[0] + ind2.shape[0] - num_equal File ~/miniconda3/lib/python3.9/site-packages/numba/core/decorators.py:241, in _jit.<locals>.wrapper(func) 239 with typeinfer.register_dispatcher(disp): 240 for sig in sigs: --> 241 disp.compile(sig) 242 disp.disable_compile() 243 return disp File ~/miniconda3/lib/python3.9/site-packages/numba/core/dispatcher.py:965, in Dispatcher.compile(self, sig) 963 with ev.trigger_event("numba:compile", data=ev_details): 964 try: --> 965 cres = self._compiler.compile(args, return_type) 966 except errors.ForceLiteralArg as e: 967 def folded(args, kws): File ~/miniconda3/lib/python3.9/site-packages/numba/core/dispatcher.py:125, in _FunctionCompiler.compile(self, args, return_type) 124 def compile(self, args, return_type): --> 125 status, retval = self._compile_cached(args, return_type) 126 if status: 127 return retval File ~/miniconda3/lib/python3.9/site-packages/numba/core/dispatcher.py:139, in _FunctionCompiler._compile_cached(self, args, return_type) 136 pass 138 try: --> 139 retval = self._compile_core(args, 
return_type) 140 except errors.TypingError as e: 141 self._failed_cache[key] = e File ~/miniconda3/lib/python3.9/site-packages/numba/core/dispatcher.py:152, in _FunctionCompiler._compile_core(self, args, return_type) 149 flags = self._customize_flags(flags) 151 impl = self._get_implementation(args, {}) --> 152 cres = compiler.compile_extra(self.targetdescr.typing_context, 153 self.targetdescr.target_context, 154 impl, 155 args=args, return_type=return_type, 156 flags=flags, locals=self.locals, 157 pipeline_class=self.pipeline_class) 158 # Check typing error if object mode is used 159 if cres.typing_error is not None and not flags.enable_pyobject: File ~/miniconda3/lib/python3.9/site-packages/numba/core/compiler.py:762, in compile_extra(typingctx, targetctx, func, args, return_type, flags, locals, library, pipeline_class) 738 """Compiler entry point 739 740 Parameter (...) 758 compiler pipeline 759 """ 760 pipeline = pipeline_class(typingctx, targetctx, library, 761 args, return_type, flags, locals) --> 762 return pipeline.compile_extra(func) File ~/miniconda3/lib/python3.9/site-packages/numba/core/compiler.py:460, in CompilerBase.compile_extra(self, func) 458 self.state.lifted = () 459 self.state.lifted_from = None --> 460 return self._compile_bytecode() File ~/miniconda3/lib/python3.9/site-packages/numba/core/compiler.py:528, in CompilerBase._compile_bytecode(self) 524 """ 525 Populate and run pipeline for bytecode input 526 """ 527 assert self.state.func_ir is None --> 528 return self._compile_core() File ~/miniconda3/lib/python3.9/site-packages/numba/core/compiler.py:494, in CompilerBase._compile_core(self) 492 res = None 493 try: --> 494 pm.run(self.state) 495 if self.state.cr is not None: 496 break File ~/miniconda3/lib/python3.9/site-packages/numba/core/compiler_machinery.py:356, in PassManager.run(self, state) 354 pass_inst = _pass_registry.get(pss).pass_inst 355 if isinstance(pass_inst, CompilerPass): --> 356 self._runPass(idx, pass_inst, state) 357 else: 
358 raise BaseException("Legacy pass in use") File ~/miniconda3/lib/python3.9/site-packages/numba/core/compiler_lock.py:35, in _CompilerLock.__call__.<locals>._acquire_compile_lock(*args, **kwargs) 32 @functools.wraps(func) 33 def _acquire_compile_lock(*args, **kwargs): 34 with self: ---> 35 return func(*args, **kwargs) File ~/miniconda3/lib/python3.9/site-packages/numba/core/compiler_machinery.py:311, in PassManager._runPass(self, index, pss, internal_state) 309 mutated |= check(pss.run_initialization, internal_state) 310 with SimpleTimer() as pass_time: --> 311 mutated |= check(pss.run_pass, internal_state) 312 with SimpleTimer() as finalize_time: 313 mutated |= check(pss.run_finalizer, internal_state) File ~/miniconda3/lib/python3.9/site-packages/numba/core/compiler_machinery.py:273, in PassManager._runPass.<locals>.check(func, compiler_state) 272 def check(func, compiler_state): --> 273 mangled = func(compiler_state) 274 if mangled not in (True, False): 275 msg = ("CompilerPass implementations should return True/False. 
" 276 "CompilerPass with name '%s' did not.") File ~/miniconda3/lib/python3.9/site-packages/numba/core/typed_passes.py:468, in BaseNativeLowering.run_pass(self, state) 466 lower.lower() 467 if not flags.no_cpython_wrapper: --> 468 lower.create_cpython_wrapper(flags.release_gil) 470 if not flags.no_cfunc_wrapper: 471 # skip cfunc wrapper generation if unsupported 472 # argument or return types are used 473 for t in state.args: File ~/miniconda3/lib/python3.9/site-packages/numba/core/lowering.py:297, in BaseLower.create_cpython_wrapper(self, release_gil) 292 if self.genlower: 293 self.context.create_cpython_wrapper(self.library, 294 self.genlower.gendesc, 295 self.env, self.call_helper, 296 release_gil=release_gil) --> 297 self.context.create_cpython_wrapper(self.library, self.fndesc, 298 self.env, self.call_helper, 299 release_gil=release_gil) File ~/miniconda3/lib/python3.9/site-packages/numba/core/cpu.py:191, in CPUContext.create_cpython_wrapper(self, library, fndesc, env, call_helper, release_gil) 187 builder = PyCallWrapper(self, wrapper_module, wrapper_callee, 188 fndesc, env, call_helper=call_helper, 189 release_gil=release_gil) 190 builder.build() --> 191 library.add_ir_module(wrapper_module) File ~/miniconda3/lib/python3.9/site-packages/numba/core/codegen.py:730, in CPUCodeLibrary.add_ir_module(self, ir_module) 728 ll_module.name = ir_module.name 729 ll_module.verify() --> 730 self.add_llvm_module(ll_module) File ~/miniconda3/lib/python3.9/site-packages/numba/core/codegen.py:737, in CPUCodeLibrary.add_llvm_module(self, ll_module) 735 if not config.LLVM_REFPRUNE_PASS: 736 ll_module = remove_redundant_nrt_refct(ll_module) --> 737 self._final_module.link_in(ll_module) File ~/miniconda3/lib/python3.9/site-packages/llvmlite/binding/module.py:174, in ModuleRef.link_in(self, other, preserve) 172 if preserve: 173 other = other.clone() --> 174 link_modules(self, other) File ~/miniconda3/lib/python3.9/site-packages/llvmlite/binding/linker.py:7, in link_modules(dst, 
src) 5 def link_modules(dst, src): 6 with ffi.OutputString() as outerr: ----> 7 err = ffi.lib.LLVMPY_LinkModules(dst, src, outerr) 8 # The underlying module was destroyed 9 src.detach() File ~/miniconda3/lib/python3.9/site-packages/llvmlite/binding/ffi.py:152, in _lib_fn_wrapper.__call__(self, *args, **kwargs) 150 def __call__(self, *args, **kwargs): 151 with self._lock: --> 152 return self._cfn(*args, **kwargs) KeyboardInterrupt:
from tqdm.notebook import tqdm
def train_out_proj(epochs, batch, out_projs=None, cat=False):
    """Train linear readout projections on top of the frozen EchoSpike network.

    One projection is trained directly on the inputs ("no layer") plus one per
    hidden layer.  Uses the notebook globals `train_loader2`, `SNN`, `args`
    and `device`.

    Args:
        epochs: number of passes over the dataloader (measured in items).
        batch: batch size used when requesting items from the loader.
        out_projs: optional list of pre-existing projections [input, layer1, ...]
            to continue training; new ones are created when None/too short.
        cat: if True, each layer's readout additionally receives the raw input
            and all earlier layers' spikes concatenated.

    Returns:
        ([out_proj_0, *out_projs], accuracy history as np.ndarray,
         stacked per-step cross-entropy losses).
    """
    # train output projections from all layers (and no layer)
    dataloader = train_loader2
    losses_out = []
    optimizers = []
    # Log/accumulate accuracy every 40 batches worth of samples.
    print_interval = 40*batch
    if out_projs is None:
        out_projs = []
        # Readout trained directly on the (flattened) input spikes.
        out_proj_0 = simple_out(args.n_inputs, args.n_outputs, beta=1.0)
    else:
        # Reuse supplied projections: put them in train mode and reset state.
        for out_p in out_projs:
            out_p.train()
            out_p.reset()
        # First entry is the input readout; the rest are per-layer readouts.
        out_proj_0 = out_projs[0]
        out_projs = out_projs[1:]
    optim_0 = torch.optim.Adam(out_proj_0.parameters(), lr=1e-2)
    for lay in range(len(SNN.layers)):
        if len(out_projs) <= lay:
            if cat:
                # Concatenated input: raw input + all hidden layers up to `lay`.
                out_projs.append(simple_out(sum(args.n_hidden[:lay+1]) + args.n_inputs, args.n_outputs, beta=1.0))
            else:
                out_projs.append(simple_out(args.n_hidden[lay], args.n_outputs, beta=1.0))
        optimizers.append(torch.optim.Adam(out_projs[lay].parameters(), lr=1e-2))
        optimizers[-1].zero_grad()
    # The EchoSpike network itself stays frozen; only the readouts learn
    # (their updates happen inside simple_out, hence the no_grad block).
    SNN.eval()
    target = batch*[0]
    acc = []
    # Per-readout correct-prediction counters (input readout + one per layer).
    correct = (len(SNN.layers) + 1)*[0]
    with torch.no_grad():
        pbar = tqdm(total=len(dataloader)*epochs)
        # Each loop iteration consumes one batch; stop after `epochs` passes.
        while len(losses_out)*batch < len(dataloader)*epochs:
            data, target = dataloader.next_item(target, contrastive=True)
            SNN.reset(0)
            logit_lists = [[] for lay in range(len(SNN.layers)+1)]
            data = data.squeeze()
            # Present the sample time-step by time-step.
            for step in range(data.shape[0]):
                data_step = data[step].float().to(device)
                target = target.to(device)
                logits, _, _ = SNN(data_step, 0)
                if step == args.n_time_bins-1:
                    # Last time step: pass the target so the readouts learn,
                    # and keep the logits for the accuracy/loss bookkeeping.
                    _, logts = out_proj_0(data_step, target)
                    logit_lists[0] = logts
                    for lay in range(len(SNN.layers)):
                        if cat:
                            _, logts = out_projs[lay](torch.cat([data_step, *logits[:lay+1]], dim=-1), target)
                        else:
                            _, logts = out_projs[lay](logits[lay], target)
                        logit_lists[lay+1] = logts
                else:
                    # Intermediate steps: integrate only (no target, no learning signal).
                    out_proj_0(data_step, None)
                    for lay in range(len(SNN.layers)):
                        if cat:
                            out_projs[lay](torch.cat([data_step, *logits[:lay+1]], dim=-1), None)
                        else:
                            out_projs[lay](logits[lay], None)
            # Predicted class per readout and running correct counts.
            preds = [logit_lists[lay].argmax(axis=-1) for lay in range(len(SNN.layers)+1)]
            # if pred.max() < 1: print(pred.max())
            dL = [preds[lay] == target for lay in range(len(SNN.layers)+1)]
            correct = [correct[lay] + dL[lay].sum() for lay in range(len(SNN.layers)+1)]
            # Reset membrane state of all readouts between samples.
            out_proj_0.reset()
            for i, out_proj in enumerate(out_projs):
                out_proj.reset()
            # Record the per-readout cross-entropy for monitoring only.
            losses_out.append(torch.tensor([torch.nn.functional.cross_entropy(logit_lists[lay], target.squeeze().long()) for lay in range(len(SNN.layers)+1)], requires_grad=False))
            optim_0.step()
            optim_0.zero_grad()
            for opt in optimizers:
                opt.step()
                opt.zero_grad()
            if len(losses_out)*batch % print_interval == 0:
                # Periodic progress report: smoothed loss and accuracy over the window.
                pbar.write(f'Cross Entropy Loss: {(torch.stack(losses_out)[-print_interval//batch:].sum(dim=0)/(print_interval//batch)).numpy()}\n' +
                           f'Correct: {100*np.array(correct)/print_interval}%')
                acc.append(np.array(correct)/print_interval)
                correct = (len(SNN.layers) + 1)*[0]
            pbar.update(batch)
    return [out_proj_0, *out_projs], np.asarray(acc), torch.stack(losses_out)
# Train the readout projections from scratch ten times with different seeds
# and record the resulting train/test accuracies for each run.
with torch.no_grad():
    cat = False
    test_accs = []
    train_accs = []
    for seed in range(10):
        # Fresh random seed per repetition.
        torch.manual_seed(seed)
        out_projs, acc, losses_out = train_out_proj(1, 30, cat=cat)
        test_accs.append(get_accuracy(test_loader, out_projs, cat=cat)[0])
        train_accs.append(get_accuracy(train_loader2, out_projs, cat=cat)[0])
0%| | 0/6000 [00:00<?, ?it/s]
Cross Entropy Loss: [2.3382182 0.47337666 0.42744297 0.37736744] Correct: [71.83333333 89.16666667 88.91666667 90.75 ]% Cross Entropy Loss: [1.6546555 0.17299062 0.17895769 0.16855808] Correct: [79.91666667 95.16666667 95.75 96.16666667]% Cross Entropy Loss: [1.2595325 0.15038992 0.16651992 0.16367355] Correct: [84.83333333 95.16666667 95.5 95.83333333]% Cross Entropy Loss: [1.2586384 0.1645219 0.17492877 0.17914307] Correct: [85.58333333 95.25 95.33333333 95.33333333]% Cross Entropy Loss: [1.553818 0.15008847 0.15458283 0.16141796] Correct: [86. 96.16666667 96.25 96.25 ]%
0%| | 0/79 [00:00<?, ?it/s]
Directly from inputs: Accuracy: 86.06% From layer 1: Accuracy: 95.28% From layer 2: Accuracy: 95.44% From layer 3: Accuracy: 95.25%
0%| | 0/47 [00:00<?, ?it/s]
Directly from inputs: Accuracy: 87.02% From layer 1: Accuracy: 96.67% From layer 2: Accuracy: 96.37% From layer 3: Accuracy: 96.00%
0%| | 0/6000 [00:00<?, ?it/s]
Cross Entropy Loss: [2.9430552 0.48246732 0.39133295 0.38660818] Correct: [66.25 87. 89.75 90.5 ]% Cross Entropy Loss: [1.4480274 0.19134916 0.201902 0.20016089] Correct: [84.25 94. 94.91666667 94.83333333]% Cross Entropy Loss: [1.0512209 0.13377151 0.1263952 0.12791191] Correct: [85.91666667 96.25 96.83333333 96.58333333]% Cross Entropy Loss: [1.3005296 0.1641307 0.18072578 0.1902675 ] Correct: [85.33333333 95.08333333 95.16666667 95.16666667]% Cross Entropy Loss: [1.1286719 0.1699802 0.1810993 0.17415996] Correct: [85.08333333 95.33333333 95.83333333 95.66666667]%
0%| | 0/79 [00:00<?, ?it/s]
Directly from inputs: Accuracy: 80.96% From layer 1: Accuracy: 95.27% From layer 2: Accuracy: 95.24% From layer 3: Accuracy: 95.23%
0%| | 0/47 [00:00<?, ?it/s]
Directly from inputs: Accuracy: 79.73% From layer 1: Accuracy: 96.52% From layer 2: Accuracy: 96.37% From layer 3: Accuracy: 96.02%
0%| | 0/6000 [00:00<?, ?it/s]
Cross Entropy Loss: [3.1067567 0.4774762 0.38488543 0.3602371 ] Correct: [66.58333333 88.91666667 90.08333333 91.08333333]% Cross Entropy Loss: [1.0771862 0.15314627 0.16538922 0.16365175] Correct: [85.83333333 95.58333333 95.5 95.58333333]% Cross Entropy Loss: [1.2817585 0.18245208 0.18243818 0.17942084] Correct: [84.83333333 95.08333333 95.33333333 94.91666667]% Cross Entropy Loss: [1.6821384 0.14353484 0.16773058 0.1662912 ] Correct: [81.91666667 95.58333333 95.5 95.25 ]% Cross Entropy Loss: [1.0285895 0.16867764 0.17331538 0.17302595] Correct: [86.5 95.66666667 95.25 95.58333333]%
0%| | 0/79 [00:00<?, ?it/s]
Directly from inputs: Accuracy: 88.08% From layer 1: Accuracy: 95.70% From layer 2: Accuracy: 95.85% From layer 3: Accuracy: 95.63%
0%| | 0/47 [00:00<?, ?it/s]
Directly from inputs: Accuracy: 90.65% From layer 1: Accuracy: 96.98% From layer 2: Accuracy: 96.23% From layer 3: Accuracy: 95.92%
0%| | 0/6000 [00:00<?, ?it/s]
Cross Entropy Loss: [2.7217298 0.46089298 0.34992224 0.3748212 ] Correct: [69.33333333 87.58333333 91.16666667 90.25 ]% Cross Entropy Loss: [1.4026086 0.15913293 0.15897939 0.16435106] Correct: [83.16666667 95.25 95.75 95.58333333]% Cross Entropy Loss: [1.5430806 0.15613322 0.15375328 0.15590101] Correct: [82.41666667 95.5 96. 95.91666667]% Cross Entropy Loss: [1.5777922 0.18366615 0.18653242 0.18294129] Correct: [83.08333333 95. 95.25 95.33333333]% Cross Entropy Loss: [1.2560295 0.1496404 0.16298625 0.15777263] Correct: [87.41666667 96.08333333 95.91666667 95.66666667]%
0%| | 0/79 [00:00<?, ?it/s]
Directly from inputs: Accuracy: 86.61% From layer 1: Accuracy: 94.98% From layer 2: Accuracy: 95.14% From layer 3: Accuracy: 95.21%
0%| | 0/47 [00:00<?, ?it/s]
Directly from inputs: Accuracy: 87.82% From layer 1: Accuracy: 96.52% From layer 2: Accuracy: 96.22% From layer 3: Accuracy: 95.83%
0%| | 0/6000 [00:00<?, ?it/s]
Cross Entropy Loss: [2.9129624 0.43650824 0.34120786 0.36636773] Correct: [65.08333333 88.83333333 90.41666667 90.41666667]% Cross Entropy Loss: [1.699683 0.1648637 0.16833332 0.16378531] Correct: [82.75 95.5 95.66666667 95.5 ]% Cross Entropy Loss: [1.3655012 0.16319947 0.17740284 0.18934996] Correct: [86. 95.41666667 95.25 95.25 ]% Cross Entropy Loss: [1.4213681 0.19052967 0.20100173 0.2064105 ] Correct: [83.41666667 94.66666667 94.33333333 94.5 ]% Cross Entropy Loss: [1.6758854 0.14381635 0.15897633 0.15392265] Correct: [83.83333333 95.91666667 96.16666667 95.41666667]%
0%| | 0/79 [00:00<?, ?it/s]
Directly from inputs: Accuracy: 86.64% From layer 1: Accuracy: 95.74% From layer 2: Accuracy: 95.75% From layer 3: Accuracy: 95.69%
0%| | 0/47 [00:00<?, ?it/s]
Directly from inputs: Accuracy: 89.78% From layer 1: Accuracy: 96.32% From layer 2: Accuracy: 96.47% From layer 3: Accuracy: 96.17%
0%| | 0/6000 [00:00<?, ?it/s]
Cross Entropy Loss: [2.6625247 0.4316801 0.3392229 0.3603861] Correct: [66.41666667 88.75 92.33333333 91.66666667]% Cross Entropy Loss: [1.0696148 0.15356013 0.16448016 0.16801438] Correct: [83.33333333 95.41666667 95.66666667 95.5 ]% Cross Entropy Loss: [1.5793362 0.17409055 0.1730139 0.17770503] Correct: [81.08333333 95. 95.33333333 95.75 ]% Cross Entropy Loss: [1.3155136 0.16008207 0.16914228 0.1606429 ] Correct: [85.5 95.41666667 95.5 95.83333333]% Cross Entropy Loss: [1.4178445 0.14877504 0.16295645 0.16942309] Correct: [85.91666667 95.83333333 95.58333333 95.91666667]%
0%| | 0/79 [00:00<?, ?it/s]
Directly from inputs: Accuracy: 87.52% From layer 1: Accuracy: 95.72% From layer 2: Accuracy: 95.47% From layer 3: Accuracy: 95.45%
0%| | 0/47 [00:00<?, ?it/s]
Directly from inputs: Accuracy: 88.10% From layer 1: Accuracy: 96.65% From layer 2: Accuracy: 96.28% From layer 3: Accuracy: 95.97%
0%| | 0/6000 [00:00<?, ?it/s]
Cross Entropy Loss: [3.376453 0.42517838 0.35822278 0.33904296] Correct: [67.16666667 89.66666667 90.25 91.5 ]% Cross Entropy Loss: [1.3644907 0.19069616 0.18601438 0.18697791] Correct: [82.33333333 94.41666667 95.08333333 94.75 ]% Cross Entropy Loss: [1.5174477 0.15634212 0.16194287 0.16461536] Correct: [81.91666667 95.16666667 95.91666667 95.75 ]% Cross Entropy Loss: [1.3004358 0.14357959 0.15617082 0.15745345] Correct: [85.75 96.08333333 96. 95.75 ]% Cross Entropy Loss: [1.1230577 0.15044114 0.16561505 0.17040502] Correct: [85.16666667 95.91666667 95.66666667 95.66666667]%
0%| | 0/79 [00:00<?, ?it/s]
Directly from inputs: Accuracy: 87.96% From layer 1: Accuracy: 95.33% From layer 2: Accuracy: 95.40% From layer 3: Accuracy: 95.34%
0%| | 0/47 [00:00<?, ?it/s]
Directly from inputs: Accuracy: 88.35% From layer 1: Accuracy: 96.65% From layer 2: Accuracy: 96.27% From layer 3: Accuracy: 96.05%
0%| | 0/6000 [00:00<?, ?it/s]
Cross Entropy Loss: [3.0279756 0.46350527 0.37065417 0.36309353] Correct: [66.25 87.66666667 91. 91.16666667]% Cross Entropy Loss: [1.0403515 0.15943103 0.14953628 0.14353053] Correct: [84.83333333 95.58333333 95.83333333 95.83333333]% Cross Entropy Loss: [1.2071066 0.18711491 0.20832297 0.21060458] Correct: [84.75 94.83333333 94.75 94.83333333]% Cross Entropy Loss: [1.4031492 0.15133992 0.155506 0.1607064 ] Correct: [83.66666667 96.25 95.66666667 95.5 ]% Cross Entropy Loss: [1.0277554 0.11540209 0.11871958 0.12798412] Correct: [88. 96.5 96.08333333 96. ]%
0%| | 0/79 [00:00<?, ?it/s]
Directly from inputs: Accuracy: 82.82% From layer 1: Accuracy: 95.08% From layer 2: Accuracy: 94.26% From layer 3: Accuracy: 94.41%
0%| | 0/47 [00:00<?, ?it/s]
Directly from inputs: Accuracy: 88.68% From layer 1: Accuracy: 96.47% From layer 2: Accuracy: 96.20% From layer 3: Accuracy: 95.78%
0%| | 0/6000 [00:00<?, ?it/s]
Cross Entropy Loss: [2.7167056 0.44559088 0.37563068 0.37984776] Correct: [70.08333333 89.66666667 90.91666667 91.16666667]% Cross Entropy Loss: [1.6023502 0.16663647 0.17960748 0.18191728] Correct: [81. 94.75 95.08333333 95.25 ]% Cross Entropy Loss: [1.4329418 0.18200654 0.19142093 0.1874464 ] Correct: [82.66666667 95.25 95.41666667 95.41666667]% Cross Entropy Loss: [0.9958844 0.14487785 0.1586114 0.16825038] Correct: [87.58333333 96.33333333 96. 96.16666667]% Cross Entropy Loss: [1.414678 0.1706594 0.17077143 0.18180208] Correct: [85.25 95.25 95. 94.41666667]%
0%| | 0/79 [00:00<?, ?it/s]
Directly from inputs: Accuracy: 83.47% From layer 1: Accuracy: 95.44% From layer 2: Accuracy: 95.67% From layer 3: Accuracy: 95.50%
0%| | 0/47 [00:00<?, ?it/s]
Directly from inputs: Accuracy: 86.95% From layer 1: Accuracy: 96.48% From layer 2: Accuracy: 96.48% From layer 3: Accuracy: 96.22%
0%| | 0/6000 [00:00<?, ?it/s]
Cross Entropy Loss: [2.7415905 0.46612364 0.3127798 0.3401905 ] Correct: [67.83333333 88.66666667 91.41666667 92. ]% Cross Entropy Loss: [1.170493 0.15758887 0.15697098 0.1570035 ] Correct: [85.08333333 95.58333333 96.08333333 96. ]% Cross Entropy Loss: [1.1175559 0.17982589 0.1780278 0.1790441 ] Correct: [85.16666667 95. 95.66666667 95.91666667]% Cross Entropy Loss: [1.1843574 0.17871569 0.18905477 0.18735561] Correct: [84.58333333 94.75 95.33333333 94.58333333]% Cross Entropy Loss: [1.1208203 0.14442256 0.153645 0.15707959] Correct: [86.5 96. 96. 95.75]%
0%| | 0/79 [00:00<?, ?it/s]
Directly from inputs: Accuracy: 87.83% From layer 1: Accuracy: 95.43% From layer 2: Accuracy: 95.71% From layer 3: Accuracy: 95.63%
0%| | 0/47 [00:00<?, ?it/s]
Directly from inputs: Accuracy: 90.67% From layer 1: Accuracy: 96.35% From layer 2: Accuracy: 96.50% From layer 3: Accuracy: 96.22%
# Report the mean accuracy over the last quarter of the recorded history.
print(f'Accuracy of last quarter: {100*acc[-len(acc)//4:].mean(axis=0)}%')

# Accuracy curves, one per readout (inputs + each hidden layer).
plt.figure()
for readout in range(acc.shape[1]):
    plt.plot(np.asarray(acc)[:, readout]*100, color=color_list[readout])
plt.ylabel('Accuracy [%]')
plt.xlabel('Training Step [x500]')
labels = ['From Inputs directly', *[f'From Layer {i+1}' for i in range(len(SNN.layers))]]
plt.legend(labels)
plt.ylim([90, 100])

# Smoothed cross-entropy curves over the same training run.
plt.figure()
epoch_axis = np.arange(len(losses_out))*args.batch_size/len(train_loader)
for readout in range(losses_out.shape[1]):
    plt.plot(epoch_axis, savgol_filter(losses_out[:, readout], 19, 1),
             label=labels[readout], color=color_list[readout])
plt.ylabel('Cross Entropy Loss')
plt.xlabel('Epoch')
plt.ylim([0, 0.6])
plt.legend();
Accuracy of last quarter: [86.625 95.9375 95.75 95.75 ]%
def get_accuracy(dataloader, out_projs, cat=False):
    """Evaluate every readout projection on `dataloader` and report accuracies.

    Runs the (module-level) `SNN` over each batch, feeds the input spikes and
    each layer's spikes into the corresponding projection in `out_projs`, and
    keeps the membrane potentials at the final time step as class logits.

    Args:
        dataloader: dataset wrapper; for 'mnist' it exposes `.x`/`.y`, otherwise
            `.data` and `.target_indeces` (indices grouped per class).
        out_projs: one readout head per tap point (inputs + each SNN layer).
        cat: if True, projection i receives the concatenation of all spike
            streams up to and including tap i instead of tap i alone.

    Returns:
        (correct, pred_matrix): per-projection accuracy fractions, and the
        target-vs-prediction count matrix for the LAST projection only.
    """
    correct = torch.zeros(len(out_projs))
    for out_proj in out_projs:
        out_proj.eval()
    SNN.eval()
    total = 0
    pred_matrix = torch.zeros(args.n_outputs, args.n_outputs)
    # Hoisted out of the batch loop: the concatenation of the per-class index
    # lists is loop-invariant, so build it once instead of once per batch.
    if args.dataset != 'mnist':
        flattened_indices = torch.cat(dataloader.target_indeces)
    for idx in trange(0, len(dataloader), args.batch_size):
        # Reset all stateful modules between batches.
        for out_proj in out_projs:
            out_proj.reset()
        SNN.reset(0)
        if args.dataset == 'mnist':
            inp, target = dataloader.x[idx:idx+args.batch_size], dataloader.y[idx:idx+args.batch_size]
        else:
            indices = flattened_indices[idx:idx+args.batch_size]
            # Last batch may be short.
            until = min(args.batch_size, len(dataloader) - idx)
            inp = torch.stack([torch.tensor(dataloader.data[indices[i]][0]).view(args.n_time_bins, -1) for i in range(until)])
            target = torch.tensor([dataloader.data[indices[i]][1] for i in range(until)])
        # One logit tensor per projection. A comprehension avoids the
        # shared-reference pitfall of `n * [tensor]` (the original relied on
        # every entry being reassigned before use).
        logits = [torch.zeros(inp.shape[0], args.n_outputs) for _ in out_projs]
        for step in range(inp.shape[1]):
            data_step = inp[:, step].float().to(device)
            spk_step, _, _ = SNN(data_step, 0)
            # Prepend the raw input so projection 0 reads the inputs directly.
            spk_step = [data_step, *spk_step]
            for i, out_proj in enumerate(out_projs):
                if cat:
                    _, mem = out_proj(torch.cat(spk_step[:i+1], dim=-1), target)
                else:
                    _, mem = out_proj(spk_step[i], target)
                # Only the membrane potential at the final time step is scored.
                if step == args.n_time_bins-1:
                    logits[i] = mem
        total += inp.shape[0]
        for i, logit in enumerate(logits):
            pred = logit.argmax(axis=-1)
            correct[i] += int((pred == target).sum())
        # for the last layer create the prediction matrix
        # (`pred` still holds the last projection's predictions here)
        for j in range(pred.shape[0]):
            pred_matrix[int(target[j]), int(pred[j])] += 1
    correct /= len(dataloader)
    # Sanity check: every sample was seen exactly once.
    assert total == len(dataloader)
    print('Directly from inputs:')
    print(f'Accuracy: {100*correct[0]:.2f}%')
    for i in range(len(out_projs)-1):
        print(f'From layer {i+1}:')
        print(f'Accuracy: {100*correct[i+1]:.2f}%')
    return correct, pred_matrix
# NOTE(review): this cell requires `out_projs` and `cat` from the training run;
# the traceback below shows they were undefined in this session (the pickle
# loading in the first cell is commented out) — re-run training or restore it.
correct, pred_matrix = get_accuracy(test_loader, out_projs, cat=cat)
plt.imshow(pred_matrix, origin='lower')
plt.title('Prediction Matrix for the final layer')
plt.xlabel('Prediction')
plt.ylabel('Target')
plt.xticks(range(args.n_outputs))  # range suffices; no need for a list comprehension
plt.yticks(range(args.n_outputs))
plt.colorbar();
--------------------------------------------------------------------------- NameError Traceback (most recent call last) Input In [4], in <cell line: 47>() 45 print(f'Accuracy: {100*correct[i+1]:.2f}%') 46 return correct, pred_matrix ---> 47 correct, pred_matrix = get_accuracy(test_loader, out_projs, cat=cat) 48 plt.imshow(pred_matrix, origin='lower') 49 plt.title('Prediction Matrix for the final layer') NameError: name 'out_projs' is not defined
# train_accs = torch.stack(train_accs)
# test_accs = torch.stack(test_accs)
print(train_accs.shape)
print(f'Train Accuracy: {100*train_accs.mean(axis=0)}%, Std: {100*train_accs.std(axis=0)}%')
print(f'Test Accuracy: {100*test_accs.mean(axis=0)}%, Std: {100*test_accs.std(axis=0)}%')
# grouped Bar plot the Accuracies of the different layers both during training and testing
sns.set_theme(style="whitegrid")
labels = ['From Inputs Directly', *[f'From Layer {i+1}' for i in range(len(SNN.layers))]]
x = torch.arange(len(labels))  # the label locations
width = 0.35  # the width of the bars
fig, ax = plt.subplots()
print(x.shape, train_accs.mean(axis=0).shape)
rects1 = ax.bar(x - width/2, 100*test_accs.mean(axis=0), width, label='Test Accuracy', color=color_list[0])
ax.errorbar(x - width/2, 100*test_accs.mean(axis=0), yerr=100*test_accs.std(axis=0), fmt='none', capsize=6, color=color_list[3])
rects2 = ax.bar(x + width/2, 100*train_accs.mean(axis=0), width, label='Train Accuracy', color=color_list[1])
ax.errorbar(x + width/2, 100*train_accs.mean(axis=0), yerr=100*train_accs.std(axis=0), fmt='none', capsize=6, color=color_list[3])
# remove horizontal lines and spines
ax.spines['right'].set_visible(False)
ax.spines['left'].set_visible(False)
ax.xaxis.grid(False)
# Tick positions come from `x`/`labels` built above; the original used
# `len(out_projs)`, which is undefined in this session (see the NameError
# traceback earlier) and redundant given `x` already spans the labels.
plt.xticks(x, labels, rotation=45)
plt.legend()
plt.ylabel('Accuracy [%]')
plt.ylim([40, 100])
torch.Size([10, 4]) Train Accuracy: tensor([87.7750, 96.5600, 96.3383, 96.0167])%, Std: tensor([3.1242, 0.1910, 0.1147, 0.1507])% Test Accuracy: tensor([85.7950, 95.3970, 95.3930, 95.3340])%, Std: tensor([2.4968, 0.2636, 0.4586, 0.3697])% torch.Size([4]) torch.Size([4])
(40.0, 100.0)
# Few-shot evaluation: build class prototypes from k samples per class, then
# classify the test set by similarity of spike counts to those prototypes.
# Repeated n_repeats times with fresh random prototype samples.
n_repeats = 10
fewshot_accuracies = torch.zeros((n_repeats, len(SNN.layers)))
for n in range(n_repeats):
    # Randomly select one sample of each class and save the spiking activity
    # NOTE(review): SNN is reset once per repeat, not between prototype samples
    # or classes — confirm this carry-over of state is intended.
    SNN.reset(0)
    one_shot_samples = torch.zeros(args.n_outputs, args.n_time_bins, args.n_inputs)
    # Accumulated spike counts per class and per hidden layer (the prototypes).
    one_shot_spks = [torch.zeros(args.n_outputs, h) for h in args.n_hidden]
    k = 20  # number of samples per class folded into each prototype ("few"-shot)
    for i in range(args.n_outputs):
        for j in range(k):
            img, _ = train_loader2.next_item(i, contrastive=False)
            # Only the last of the k samples is kept here (overwritten each j).
            one_shot_samples[i] = img.squeeze()
            for t in range(args.n_time_bins):
                logits, _, _ = SNN(img[t].float(), 0)
                for idx, log in enumerate(logits):
                    # Sum spikes over time bins and over the k samples.
                    one_shot_spks[idx][i] += log.squeeze()
    def metric(spk, one_shot):
        # Similarity of a batch of spike-count vectors to each class prototype:
        # dot product against the L1-normalized prototype.
        dists = torch.zeros(spk.shape[0], args.n_outputs)
        for i in range(args.n_outputs):
            one_shot_i = one_shot[i] / one_shot[i].sum()
            dists[:, i] = torch.einsum('bi, i->b' , spk, one_shot_i)
        return dists
    def get_predictions(spks):
        # Predicted class per layer: argmax similarity to that layer's prototypes.
        preds = torch.zeros(len(spks), spks[0].shape[0])
        # for each layer get the prediction
        for i in range(len(spks)):
            dists = metric(spks[i], one_shot_spks[i])
            preds[i] = dists.argmax(axis=-1)
        return preds
    batch = int(len(test_loader)/10)  # evaluate the test set in 10 chunks
    correct_oneshot = torch.zeros(len(SNN.layers))
    SNN.eval()
    pred_matrix_oneshot = torch.zeros(args.n_outputs, args.n_outputs)
    for idx in trange(0, len(test_loader), batch):
        SNN.reset(0)
        if args.dataset == 'mnist':
            inp, target = test_loader.x[idx:idx+batch], test_loader.y[idx:idx+batch]
        else:
            # Last chunk may be short.
            until = min(batch, len(test_loader.data) - idx)
            inp = torch.stack([torch.tensor(test_loader.data[idx+i][0]).view(args.n_time_bins, -1) for i in range(until)])
            target = torch.tensor([test_loader.data[idx+i][1] for i in range(until)])
        # Per-layer spike counts accumulated over all time steps.
        logits = [torch.zeros(inp.shape[0], h) for h in args.n_hidden]
        for step in range(inp.shape[1]):
            data_step = inp[:,step].float().to(device)
            spk_step, _, _ = SNN(data_step, 0)
            for logidx in range(len(spk_step)):
                logits[logidx] += spk_step[logidx]
        preds = get_predictions(logits)
        for i in range(preds.shape[0]):
            correct_oneshot[i] += int((preds[i] == target).sum())
        # for the last layer create the prediction matrix
        for j in range(preds.shape[1]):
            pred_matrix_oneshot[int(target[j]), int(preds[-1, j])] += 1
    correct_oneshot /= len(test_loader)
    for i in range(len(SNN.layers)):
        print(f'From layer {i+1}:')
        print(f'Accuracy: {100*correct_oneshot[i]:.2f}%')
    fewshot_accuracies[n] = correct_oneshot
    # Show the last repeat's confusion matrix for the deepest layer.
    plt.imshow(pred_matrix_oneshot, origin='lower')
    plt.title('Prediction Matrix for the final layer')
    plt.xlabel('Prediction')
    plt.ylabel('Target')
    plt.xticks([i for i in range(args.n_outputs)])
    plt.yticks([i for i in range(args.n_outputs)])
    plt.colorbar();
    plt.show()
    # Row-normalized diagonal = per-class recall (rows are targets, so axis=1 is correct).
    print(f'Accuracy per Label: {100*pred_matrix_oneshot.diag()/pred_matrix_oneshot.sum(axis=1)}%') # correct axis?
0%| | 0/10 [00:00<?, ?it/s]
From layer 1: Accuracy: 93.77% From layer 2: Accuracy: 95.30% From layer 3: Accuracy: 95.57%
/tmp/ipykernel_1261198/1569276281.py:68: MatplotlibDeprecationWarning: Auto-removal of grids by pcolor() and pcolormesh() is deprecated since 3.5 and will be removed two minor releases later; please call grid(False) first. plt.colorbar();
Accuracy per Label: tensor([99.1443, 98.5232, 95.1907, 94.0629, 96.9880, 94.7781, 97.1899, 92.6627,
94.4107, 91.6963])%
0%| | 0/10 [00:00<?, ?it/s]
From layer 1: Accuracy: 94.05% From layer 2: Accuracy: 95.18% From layer 3: Accuracy: 95.33%
Accuracy per Label: tensor([99.2665, 98.6287, 93.2007, 94.0629, 96.8675, 94.7781, 97.3269, 93.0178,
95.0182, 91.8149])%
0%| | 0/10 [00:00<?, ?it/s]
From layer 1: Accuracy: 94.40% From layer 2: Accuracy: 95.32% From layer 3: Accuracy: 95.59%
Accuracy per Label: tensor([99.0220, 98.6287, 95.2460, 93.8300, 96.7470, 94.3864, 97.2584, 92.8994,
95.2612, 91.4591])%
0%| | 0/10 [00:00<?, ?it/s]
From layer 1: Accuracy: 94.08% From layer 2: Accuracy: 95.15% From layer 3: Accuracy: 95.47%
Accuracy per Label: tensor([99.0220, 98.5232, 94.6932, 94.0629, 96.8675, 94.9086, 96.6415, 93.1361,
95.0182, 91.5777])%
0%| | 0/10 [00:00<?, ?it/s]
From layer 1: Accuracy: 93.97% From layer 2: Accuracy: 95.18% From layer 3: Accuracy: 95.38%
Accuracy per Label: tensor([99.0220, 98.5232, 93.3112, 95.1106, 96.7470, 93.6031, 97.4640, 93.7278,
95.1397, 91.5777])%
0%| | 0/10 [00:00<?, ?it/s]
From layer 1: Accuracy: 93.83% From layer 2: Accuracy: 95.11% From layer 3: Accuracy: 95.41%
Accuracy per Label: tensor([98.8998, 98.7342, 94.2510, 93.9464, 96.7470, 94.9086, 96.9842, 92.8994,
94.6537, 91.9336])%
0%| | 0/10 [00:00<?, ?it/s]
From layer 1: Accuracy: 94.25% From layer 2: Accuracy: 95.25% From layer 3: Accuracy: 95.68%
Accuracy per Label: tensor([98.4108, 98.6287, 95.0249, 94.1793, 96.8675, 94.9086, 97.4640, 94.3195,
94.5322, 91.5777])%
0%| | 0/10 [00:00<?, ?it/s]
From layer 1: Accuracy: 93.02% From layer 2: Accuracy: 94.84% From layer 3: Accuracy: 95.23%
Accuracy per Label: tensor([99.1443, 98.4177, 92.9243, 93.8300, 96.6265, 94.9086, 97.1899, 93.2544,
95.6258, 91.3405])%
0%| | 0/10 [00:00<?, ?it/s]
From layer 1: Accuracy: 93.58% From layer 2: Accuracy: 95.04% From layer 3: Accuracy: 95.32%
Accuracy per Label: tensor([99.1443, 98.5232, 93.4771, 94.2957, 96.7470, 95.1697, 97.1899, 93.1361,
94.2892, 91.6963])%
0%| | 0/10 [00:00<?, ?it/s]
From layer 1: Accuracy: 94.04% From layer 2: Accuracy: 95.08% From layer 3: Accuracy: 95.38%
Accuracy per Label: tensor([98.8998, 98.6287, 93.8640, 93.7136, 97.2289, 94.9086, 97.0528, 93.7278,
94.6537, 91.3405])%
# Boxplot of the accuracies
plt.figure()
sns.set_style("whitegrid")
g = sns.boxplot(data=fewshot_accuracies*100)
# remove left spines
sns.despine(left=True)
plt.xticks(np.arange(len(SNN.layers)), [f'Layer {i+1}' for i in range(len(SNN.layers))])
plt.ylabel('Few-Shot Test Accuracy [%]')
plt.ylim([90, 100])
print(f'Average Accuracy: {100*fewshot_accuracies.mean(axis=0)}%')
# torch.Tensor.max(axis=...) returns a (values, indices) named tuple; the
# original printed the whole tuple with raw [0,1] fractions and a misleading
# '%' suffix — report the values scaled to percent instead.
print(f'Maximum Accuracy: {100*fewshot_accuracies.max(axis=0).values}%')
Average Accuracy: tensor([93.8990, 95.1450, 95.4360])% Maximum Accuracy: torch.return_types.max( values=tensor([0.9440, 0.9532, 0.9568]), indices=tensor([2, 2, 6]))%